from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from edward.models.random_variable import RandomVariable
from tensorflow.contrib.distributions import Distribution

try:
  from tensorflow.contrib.distributions import FULLY_REPARAMETERIZED
except Exception as e:
  raise ImportError("{0}. Your TensorFlow version is not supported.".format(e))


class distributions_Empirical(Distribution):
  """Empirical random variable.

  A distribution represented by a fixed collection of samples ("atoms"),
  stored along the outer dimension of `params`, each carrying equal
  probability mass.

  #### Examples

  ```python
  # 100 samples of a scalar
  x = Empirical(params=tf.zeros(100))
  assert x.shape == ()

  # 5 samples of a 2 x 3 matrix
  x = Empirical(params=tf.zeros([5, 2, 3]))
  assert x.shape == (2, 3)
  ```
  """
  def __init__(self,
               params,
               validate_args=False,
               allow_nan_stats=True,
               name="Empirical"):
    """Initialize an `Empirical` random variable.

    Args:
      params: tf.Tensor.
        Collection of samples. Its outer (left-most) dimension
        determines the number of samples.
      validate_args: Python bool, forwarded to the `Distribution` base class.
      allow_nan_stats: Python bool, forwarded to the `Distribution` base class.
      name: Python str, name scope for ops created by this distribution.
    """
    # Capture the constructor arguments before any other locals are created,
    # so this dict holds exactly the call arguments (plus `self`), matching
    # the tf.distributions convention for `parameters`.
    parameters = locals()
    with tf.name_scope(name, values=[params]):
      with tf.control_dependencies([]):
        self._params = tf.identity(params, name="params")
        try:
          # Number of samples = size of the outer (left-most) dimension.
          self._n = tf.shape(self._params)[0]
        except ValueError:  # scalar params
          # 0-d params has no outer dimension to index; treat as one sample.
          self._n = tf.constant(1)

    super(distributions_Empirical, self).__init__(
        dtype=self._params.dtype,
        # Samples are the parameter itself, so they are fully reparameterized.
        reparameterization_type=FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._params, self._n],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    # The stored sample collection is the only parameter.
    return {"params": tf.convert_to_tensor(sample_shape, dtype=tf.int32)}

  @property
  def params(self):
    """Distribution parameter."""
    return self._params

  @property
  def n(self):
    """Number of samples."""
    return self._n

  def _batch_shape_tensor(self):
    # An Empirical distribution carries no batch dimensions.
    return tf.constant([], dtype=tf.int32)

  def _batch_shape(self):
    return tf.TensorShape([])

  def _event_shape_tensor(self):
    # Event shape is everything after the leading sample dimension.
    return tf.shape(self.params)[1:]

  def _event_shape(self):
    return self.params.shape[1:]

  def _mean(self):
    # Sample mean over the collection of atoms (axis 0).
    return tf.reduce_mean(self.params, 0)

  def _stddev(self):
    # Population (divide-by-n) standard deviation across the samples.
    # broadcasting n x shape - shape = n x shape
    r = self.params - self.mean()
    return tf.sqrt(tf.reduce_mean(tf.square(r), 0))

  def _variance(self):
    return tf.square(self.stddev())

  def _sample_n(self, n, seed=None):
    input_tensor = self.params
    if len(input_tensor.shape) == 0:
      # Scalar params: replicate the single atom n times.
      # NOTE(review): len() of an unknown-rank shape would raise here —
      # presumably params always has a statically known rank; confirm.
      input_tensor = tf.expand_dims(input_tensor, 0)
      multiples = tf.concat(
          [tf.expand_dims(n, 0), [1] * len(self.event_shape)], 0)
      return tf.tile(input_tensor, multiples)
    else:
      # Resample with replacement: draw n indices uniformly over the stored
      # samples, then gather the corresponding atoms.
      probs = tf.ones([self.n]) / tf.cast(self.n, dtype=tf.float32)
      cat = tf.contrib.distributions.Categorical(probs)
      # NOTE(review): calls Categorical's private _sample_n directly.
      indices = cat._sample_n(n, seed)
      tensor = tf.gather(input_tensor, indices)
      return tensor


# Mirror TensorFlow's autogenerated random-variable wrappers: build an
# `Empirical` class at module scope that combines RandomVariable with the
# distribution implemented above.
def __init__(self, *args, **kwargs):
  RandomVariable.__init__(self, *args, **kwargs)


_name = 'Empirical'
_candidate = distributions_Empirical
# Reuse the distribution's docstrings so help(Empirical) matches
# help(distributions_Empirical).
__init__.__doc__ = _candidate.__init__.__doc__
_globals = globals()
_params = dict(
    __doc__=_candidate.__doc__,
    __init__=__init__,
    support='points')
_globals[_name] = type(_name, (RandomVariable, _candidate), _params)
